{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d6p8EySq1zXZ"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "KsOkK8O69PyT"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F1xIRPtY0E1w"
      },
      "source": [
        "# Create an Estimator from a Keras model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r61fkA2i9Y3_"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/estimator/keras_model_to_estimator\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/estimator/keras_model_to_estimator.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/estimator/keras_model_to_estimator.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/estimator/keras_model_to_estimator.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Dhcq8Ds4mCtm"
      },
      "source": [
        "> Warning: Estimators are not recommended for new code. Estimators run `v1.Session`-style code, which is more difficult to write correctly and can behave unexpectedly, especially when combined with TF 2 code. Estimators do fall under our [compatibility guarantees](https://tensorflow.org/guide/versions), but will receive no fixes other than for security vulnerabilities. See the [migration guide](https://tensorflow.org/guide/migrate) for details."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZaGcclVLwqDS"
      },
      "source": [
        "## Overview\n",
        "\n",
        "TensorFlow still supports Estimators, and you can create one from a new or existing `tf.keras` model. This tutorial contains a complete, minimal example of that process.\n",
        "\n",
        "Note: If you have a Keras model, you can use it directly with [`tf.distribute` strategies](https://tensorflow.org/guide/migrate/guide/distributed_training) without converting it to an Estimator. As such, `model_to_estimator` is no longer recommended."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "epgfaZmO2vF0"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Qmq4FzaztASN"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "import numpy as np\n",
        "import tensorflow_datasets as tfds"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9ZUATGJGtQIU"
      },
      "source": [
        "### Create a simple Keras model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rR-zPidHyzcb"
      },
      "source": [
        "In Keras, you assemble *layers* to build *models*. A model is (usually) a graph\n",
        "of layers. The most common type of model is a stack of layers: the\n",
        "`tf.keras.Sequential` model.\n",
        "\n",
        "To build a simple, fully connected network (that is, a multi-layer perceptron)\n",
        "that takes the four Iris features as input and outputs one logit per class:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "p5NSx38itD1a"
      },
      "outputs": [],
      "source": [
        "model = tf.keras.models.Sequential([\n",
        "    tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)),\n",
        "    tf.keras.layers.Dropout(0.2),\n",
        "    tf.keras.layers.Dense(3)\n",
        "])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ABgo9-8BtYNs"
      },
      "source": [
        "Compile the model and print a summary. The final `Dense` layer outputs raw logits, so the loss is created with `from_logits=True`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nViACuBDtVEC"
      },
      "outputs": [],
      "source": [
        "model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "              optimizer='adam')\n",
        "model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pM3Cx5Fm_sHI"
      },
      "source": [
        "### Create an input function\n",
        "\n",
        "Use the [Datasets API](../../guide/data.md) to scale to large datasets\n",
        "or multi-device training.\n",
        "\n",
        "Estimators need to control when and how their input pipeline is built. To allow this, they require an \"input function\", or `input_fn`. The `Estimator` calls this function with no arguments, and the `input_fn` must return a `tf.data.Dataset`. In this example the dataset yields `(features, labels)` batches, where `features` is a dictionary whose key, `dense_input`, matches the name of the Keras model's input (listed in `model.input_names`)."
      ]
    },
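    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "H0DpLEop_x0o"
      },
      "outputs": [],
      "source": [
        "def input_fn():\n",
        "  # Load the Iris training split as (features, label) pairs.\n",
        "  split = tfds.Split.TRAIN\n",
        "  dataset = tfds.load('iris', split=split, as_supervised=True)\n",
        "  # Wrap the features in a dict keyed by the Keras model's input name.\n",
        "  dataset = dataset.map(lambda features, labels: ({'dense_input': features}, labels))\n",
        "  # Batch the examples and repeat indefinitely; `steps` bounds train and evaluate.\n",
        "  dataset = dataset.batch(32).repeat()\n",
        "  return dataset"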
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UR1vRw1bBFjo"
      },
      "source": [
        "Test your `input_fn` by printing a single batch:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WO94bGYKBKRv"
      },
      "outputs": [],
      "source": [
        "for features_batch, labels_batch in input_fn().take(1):\n",
        "  print(features_batch)\n",
        "  print(labels_batch)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "svdhkQ4Otcv0"
      },
      "source": [
        "### Create an Estimator from the `tf.keras` model\n",
        "\n",
        "A `tf.keras.Model` can be trained with the `tf.estimator` API by converting the\n",
        "model to a `tf.estimator.Estimator` object with\n",
        "`tf.keras.estimator.model_to_estimator`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "roChngg8t7il"
      },
      "outputs": [],
      "source": [
        "import tempfile\n",
        "\n",
        "# The Estimator writes its checkpoints and summaries to `model_dir`.\n",
        "model_dir = tempfile.mkdtemp()\n",
        "keras_estimator = tf.keras.estimator.model_to_estimator(\n",
        "    keras_model=model, model_dir=model_dir)"
      ]
    },
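    {
      "cell_type": "markdown",
      "metadata": {
        "id": "U-8ekW5It_2w"
      },
      "source": [
        "Train and evaluate the estimator. Because `input_fn` repeats the dataset indefinitely, the `steps` argument determines how many batches each call processes."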
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ouIkVtp9uAg5"
      },
      "outputs": [],
      "source": [
        "keras_estimator.train(input_fn=input_fn, steps=500)\n",
        "eval_result = keras_estimator.evaluate(input_fn=input_fn, steps=10)\n",
        "print('Eval result: {}'.format(eval_result))"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "keras_model_to_estimator.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
